import torch
from torch import nn


# 多分类交叉熵损失，使用nn.CrossEntropyLoss()实现。nn.CrossEntropyLoss()=softmax + 损失计算
def test1():
    # 设置真实值: 可以是热编码后的结果也可以不进行热编码
    # y_true = torch.tensor([[0, 1, 0], [0, 0, 1]], dtype=torch.float32)
    # 注意的类型必须是64位整型数据
    y_true = torch.tensor([1, 2], dtype=torch.int64)
    y_pred = torch.tensor([[0.2, 0.6, 0.2], [0.1, 0.8, 0.1]], dtype=torch.float32)
    # 实例化交叉熵损失
    loss = nn.CrossEntropyLoss()
    # 计算损失结果
    my_loss = loss(y_pred, y_true).numpy()
    print('loss:', my_loss)


# 二分类交叉熵损失，使用nn.BCELoss()实现
def test2():
    # 1 设置真实值和预测值
    # 预测值是sigmoid输出的结果
    y_pred = torch.tensor([0.6901, 0.5459, 0.2469], requires_grad=True)
    y_true = torch.tensor([0, 1, 0], dtype=torch.float32)
    # 2 实例化二分类交叉熵损失
    criterion = nn.BCELoss()
    # 3 计算损失
    my_loss = criterion(y_pred, y_true).detach().numpy()
    print('loss：', my_loss)


# 回归任务损失，MAE损失，使用nn.L1Loss()实现
# 计算算inputs与target之差的绝对值
def test3():
    # 1 设置真实值和预测值
    y_pred = torch.tensor([1.0, 1.0, 1.9], requires_grad=True)
    y_true = torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32)
    # 2 实例MAE损失对象
    loss = nn.L1Loss()
    # 3 计算损失
    my_loss = loss(y_pred, y_true).detach().numpy()
    print('loss:', my_loss)


# 回归任务损失，MSE损失，使用nn.MSELoss()实现
def test4():
    # 1 设置真实值和预测值
    y_pred = torch.tensor([1.0, 1.0, 1.9], requires_grad=True)
    y_true = torch.tensor([2.0, 2.0, 2.0], dtype=torch.float32)
    # 2 实例MSE损失对象
    loss = nn.MSELoss()
    # 3 计算损失
    my_loss = loss(y_pred, y_true).detach().numpy()
    print('myloss:', my_loss)

# 回归任务损失，SmoothL1损失，使用nn.SmoothL1Loss()实现
def test5():
    # 1 设置真实值和预测值
    y_true = torch.tensor([0, 3])
    y_pred = torch.tensor([0.6, 0.4], requires_grad=True)
    # 2 实例化smoothL1损失对象
    loss = nn.SmoothL1Loss()
    # 3 计算损失
    my_loss = loss(y_pred, y_true).detach().numpy()
    print('loss:', my_loss)


if __name__ == '__main__':
    # 分类损失-多分类交叉熵损失
    test1()

    # 分类损失-二分类交叉熵损失
    # test2()

    # 回归损失-MAE损失
    # test3()

    # 回归损失-MSE损失
    # test4()

    # 回归损失-SmoothL1损失
    test5()
